import sys

sys.path.append("../modules")

import gc
import os
import pickle
import sys

from datasets import load_dataset
from tqdm import tqdm

import utils
from tree_parsing import treeparsing_utils as parseutils


def filter(croot, lang):
    """Decide whether a parsed file should be rejected before token collection.

    Only PHP files are ever rejected: a PHP file is dropped when the second
    value returned by ``parseutils.count_nodes`` (the non-BPE node count) is
    5 or less. NOTE(review): presumably this weeds out PHP files that are
    almost entirely inline text/HTML — confirm.

    The name shadows the builtin ``filter``; it is kept because the main
    loop calls it by this name.

    Args:
        croot: root of the custom tree built by ``parseutils.create_custom_tree``.
        lang: language name, e.g. ``"php"``.

    Returns:
        True to reject the file, False to keep it.
    """
    if lang != "php":
        return False
    # count_nodes yields a (bpe, nonbpe) pair; only the non-BPE count matters here.
    _, nonbpe_total = parseutils.count_nodes(croot)
    return bool(nonbpe_total <= 5)


# Samples whose dataset "size" field exceeds this are skipped by the main loop.
MAX_SIZE = 100_000
AUTH_TOKEN = None  # TODO: Add token

ALL_LANGS = ["python", "java", "c", "c_sharp", "php", "go", "javascript", "ruby"]

# Validate the CLI argument explicitly: ``assert`` is stripped under ``python -O``,
# and a missing argument would otherwise surface as a bare IndexError.
if len(sys.argv) < 2 or sys.argv[1] not in ALL_LANGS:
    sys.exit(f"usage: {sys.argv[0]} <lang>  (lang is one of: {', '.join(ALL_LANGS)})")
lang = sys.argv[1]
LANGS = [lang]

# Local language names -> directory names used under "data/" in the HF dataset.
HF_LANGMAP = {"c_sharp": "c-sharp"}

# Traversal mode forwarded to parseutils.create_tokens_dict.
traversal_type = "preorder_dfs_nodeleaf_toks"
nsamples = int(1e6)  # use only 1 million files per language to build vocab

# Per-language accumulated token sets, split into "bpe" and "nonbpe" categories.
tokendict = {lang: {"bpe": set(), "nonbpe": set()} for lang in LANGS}

# Per-language file tallies: "+" tokenized successfully, "-" failed to parse /
# tokenize or rejected by filter(), "#" skipped for size or utils.skip_sample().
counts = {lang: {"+": 0, "-": 0, "#": 0} for lang in LANGS}

# Create the output directory before the long streaming run, so the pickle
# dumps at the end cannot fail on a missing directory and lose the work.
os.makedirs("../artifacts/tokenizer/tokendict", exist_ok=True)

for lang in tqdm(LANGS, desc="Language"):
    hf_lang = HF_LANGMAP.get(lang, lang)
    ds = load_dataset(
        "bigcode/the-stack-dedup", data_dir=f"data/{hf_lang}", split="train", streaming=True, use_auth_token=AUTH_TOKEN
    )

    # The bar advances once per streamed sample (skipped and failed ones included),
    # so `total` is only an approximate target.
    for sample in tqdm(ds, total=nsamples):
        # Stop once the quota of successfully tokenized files is reached. Checking
        # before the skip test keeps samples read after the quota out of the "#"
        # tally, and honours nsamples == 0 immediately.
        if counts[lang]["+"] >= nsamples:
            break

        hexsha = sample["hexsha"]
        content = sample["content"]
        size = sample["size"]

        if size > MAX_SIZE or utils.skip_sample(hexsha):
            counts[lang]["#"] += 1
            continue

        try:
            root = parseutils.create_TS_tree(content, lang, False)
            root_custom = parseutils.create_custom_tree(root, lang, False)

            if filter(root_custom, lang):
                # Files rejected by filter() are tallied as failures, like parse errors.
                counts[lang]["-"] += 1
                continue

            tokendict[lang], _ = parseutils.create_tokens_dict(
                root_custom, tokendict[lang], {}, None, traversal_type, False
            )
        except Exception:
            # Best effort: any per-file parsing/tokenization failure is counted and
            # the file is skipped rather than aborting the whole run.
            counts[lang]["-"] += 1
            continue

        counts[lang]["+"] += 1

    # Checkpoint the raw (not yet cleaned) token sets and tallies for this language.
    # The end of the script writes to these same paths again after cleaning.
    with open(f"../artifacts/tokenizer/tokendict/token_dict_{lang}.pkl", "wb") as f:
        pickle.dump(tokendict, f)

    with open(f"../artifacts/tokenizer/tokendict/filecounts_{lang}.pkl", "wb") as f:
        pickle.dump(counts, f)


# Clean the non-BPE token set of every language in place.
for lang, tokens in tokendict.items():
    tokens["nonbpe"] = utils.clean_nonbpe_tokens(tokens["nonbpe"], lang)


# Report the vocabulary size of each token category per language.
for lang in LANGS:
    print(lang)
    for k, v in tokendict[lang].items():
        print("\t", k, len(v))


# Overwrite the per-language checkpoints with the cleaned token sets. The loop is
# explicit rather than relying on `lang` leaking out of the loop above, which
# would write only the last language's files.
for lang in LANGS:
    with open(f"../artifacts/tokenizer/tokendict/token_dict_{lang}.pkl", "wb") as f:
        pickle.dump(tokendict, f)

    with open(f"../artifacts/tokenizer/tokendict/filecounts_{lang}.pkl", "wb") as f:
        pickle.dump(counts, f)
